#!/usr/bin/env python3
import os
# Echo two environment variables at import time (None when unset).
# NOTE(review): looks like leftover debugging output — consider removing.
print('pypath', os.environ.get('PYTHONPATH'))
print('foo', os.environ.get('FOO'))
from pathlib import Path
import json
from rpi.scripts.pretraining.experts import (
    cheetah_ppo,
    cheetah_sac,
    walker_ppo,
    walker_sac,
    pendulum_ppo,
    pendulum_sac,
    cartpole_ppo,
    cartpole_sac,
    minigrid_empty8x8_ppo,
)


def convert(model_infos):
    """Reduce each expert-info mapping to a ``(policy, path)`` tuple.

    Args:
        model_infos: Iterable of mappings that contain at least the
            ``'policy'`` and ``'path'`` keys; other keys are ignored.

    Returns:
        A new list of ``(policy, path)`` tuples, in input order.
    """
    pairs = []
    for info in model_infos:
        pairs.append((info['policy'], info['path']))
    return pairs


# Rebind each imported expert-info list to its list of (policy, path) tuples.
# The right-hand side is evaluated from the original imported values, in order.
(
    cheetah_ppo,
    cheetah_sac,
    walker_ppo,
    walker_sac,
    pendulum_ppo,
    pendulum_sac,
    cartpole_ppo,
    cartpole_sac,
    minigrid_empty8x8_ppo,
) = [
    convert(infos)
    for infos in (
        cheetah_ppo,
        cheetah_sac,
        walker_ppo,
        walker_sac,
        pendulum_ppo,
        pendulum_sac,
        cartpole_ppo,
        cartpole_sac,
        minigrid_empty8x8_ppo,
    )
]


# Hyperparameters shared by every sweep entry; create_sweep splats these into
# each config line after the per-run keys.
defaults = dict(
    lmd=0.9,
    reset_expert_vfn=False,
    use_ppo_loss=True,
    std_from_means=True,
    use_expert_obsnormalizer=True,
    state_pred_num_epochs=8,
    deterministic_experts=True,  # Originally: False
    learner_buffer_size=2048,
    expert_buffer_size=19200,
    max_episode_len=1000,
    num_epochs=2,
    num_rollouts=8,
    expert_tgtval="gae",
    gamma=0.995,
    pret_num_epochs=32,
    pret_num_rollouts=8,
    pret_num_val_iterations=32,
)


def create_sweep(fname, envs, env2experts_list, env2ase_sigma, algorithms, learner_pis, seeds, pggae=False,
                 env2pggae_experts=None):
    """Write a JSON Lines sweep file with one config dict per run.

    For every seed and env, one line is emitted per expert set in
    ``env2experts_list[env]`` and per ``(learner_pi, algorithm)`` pair. When
    ``pggae`` is true, one extra ``pg-gae`` baseline line is also emitted per
    (seed, env). Each line is merged with the module-level ``defaults`` and
    serialized with sorted keys.

    Args:
        fname: Output path; the file is overwritten.
        envs: Environment names; each is prefixed with ``'minigrid:'`` to
            build the ``env_name`` field.
        env2experts_list: Maps env -> list of expert sets. Each expert set is
            written as both ``load_expert_step`` and ``experts_info``.
        env2ase_sigma: Maps env -> ``ase_sigma``. It is only read for the
            ``lops-aps-ase`` algorithm; all other algorithms get ``0.``.
        algorithms: Algorithm names, paired positionally with ``learner_pis``.
        learner_pis: ``use_riro_for_learner_pi`` values, one per algorithm.
        seeds: Random seeds to sweep over.
        pggae: If true, also emit a ``pg-gae`` baseline line per (seed, env).
        env2pggae_experts: Optional map env -> expert list used as
            ``experts_info`` for the pg-gae baseline. Defaults to the first
            MiniGrid-Empty-8x8 PPO expert.

    Raises:
        ValueError: If ``algorithms`` and ``learner_pis`` differ in length.
        KeyError: If an env is missing from ``env2experts_list``, from
            ``env2ase_sigma`` (for ``lops-aps-ase``), or from the pg-gae
            expert table.
    """
    # zip() pairs the two lists positionally and silently truncates to the
    # shorter one, which would quietly drop configs (e.g. lops-aps-ase).
    if len(algorithms) != len(learner_pis):
        raise ValueError(
            f'algorithms ({len(algorithms)}) and learner_pis ({len(learner_pis)}) '
            'must have the same length'
        )

    if pggae and env2pggae_experts is None:
        # pg-gae expert table; loop-invariant, so it is built once.
        # NOTE(review): the commented keys are lowercase ('cheetah-run') while the
        # commented envs in __main__ use 'Cheetah-run' — align before re-enabling.
        env2pggae_experts = {
            'MiniGrid-Empty-8x8': minigrid_empty8x8_ppo[:1],
            # 'cheetah-run': cheetah_ppo[:1],
            # 'cartpole-swingup': cartpole_ppo[:1],
            # 'walker-walk': walker_ppo[:1],
            # 'pendulum-swingup': pendulum_ppo[:1]
        }

    lines = []
    for seed in seeds:
        for env in envs:
            env_name = f'minigrid:{env}'
            print('env', env, 'env_name', env_name)
            # mamba and lops-aps (and lops-aps-ase) runs
            for experts in env2experts_list[env]:
                for learner_pi, algorithm in zip(learner_pis, algorithms):
                    ase_sigma = env2ase_sigma[env] if algorithm == "lops-aps-ase" else 0.
                    lines.append(
                        {
                            "env_name": env_name,
                            # NOTE(review): receives the same expert list as experts_info,
                            # while the pg-gae line below uses [0] — confirm intended.
                            "load_expert_step": experts,
                            "experts_info": experts,
                            "algorithm": algorithm,
                            "use_riro_for_learner_pi": learner_pi,
                            "ase_sigma": ase_sigma,
                            "seed": seed,
                            **defaults
                        }
                    )
            if pggae:
                # pg-gae baseline: one line per (seed, env).
                lines.append(
                    {
                        "env_name": env_name,
                        "load_expert_step": [0],
                        "experts_info": env2pggae_experts[env],
                        "algorithm": "pg-gae",
                        "use_riro_for_learner_pi": "none",
                        "ase_sigma": 0,
                        "seed": seed,
                        **defaults
                    }
                )

    json_text = [json.dumps(line, sort_keys=True) for line in lines]
    print(f'{len(json_text)} lines to {fname}')
    with open(fname, 'w', encoding='utf-8') as f:
        f.write('\n'.join(json_text))


if __name__ == '__main__':
    import sys
    # The output is named after this script (foo.py -> foo.jsonl) and is written to the CWD.
    this_file_name = sys.argv[0]

    # Variables to sweep over
    # envs = ['Cheetah-run', 'Walker-walk', 'Pendulum-swingup', 'Cartpole-swingup']
    envs = ['MiniGrid-Empty-8x8']

    # env -> list of expert sets; each set yields its own group of runs.
    # [::5][:3] keeps every 5th MiniGrid PPO expert, then the first three of those.
    env2experts_list = {
        'MiniGrid-Empty-8x8': [minigrid_empty8x8_ppo[::5][:3]],
        # 'Cheetah-run': [cheetah_ppo[:3], cheetah_ppo[::4][:3], cheetah_sac[:3], cheetah_sac[::4][:3], cheetah_sac[-3:]],
        # 'Walker-walk': [walker_ppo[:3], walker_ppo[::4][:3], walker_sac[:3], walker_sac[::4][:3], walker_sac[-3:]],
        # 'Pendulum-swingup': [pendulum_ppo[:3], pendulum_ppo[::4][:3], pendulum_sac[:3], pendulum_sac[-3:]],
        # 'Cartpole-swingup': [cartpole_ppo[:3], cartpole_ppo[::4][:3], cartpole_sac[:3], cartpole_sac[::4][:3], cartpole_sac[-3:]]
    }
    # env -> ase_sigma, only used by the lops-aps-ase algorithm.
    env2ase_sigma = {
        'MiniGrid-Empty-8x8': 0.25,
        # 'Cheetah-run': 2.5,
        # 'Walker-walk': 10,
        # 'Pendulum-swingup': 0.25,
        # 'Cartpole-swingup': 0.25,  # <-- We should run a sweep to find out a good value for this
    }

    seeds = list(range(5))
    # learner_pis and algorithms are zipped positionally in create_sweep (not a
    # cross product): mamba -> 'none', lops-aps -> 'all'. Enabling the
    # three-algorithm list below also requires a third learner_pis entry.
    learner_pis = ['none', 'all']
    # algorithms = ['mamba', 'lops-aps', 'lops-aps-ase']
    algorithms = ['mamba', 'lops-aps']

    fname = Path(this_file_name).stem + '.jsonl'
    create_sweep(fname, envs, env2experts_list, env2ase_sigma, algorithms, learner_pis, seeds, pggae=True)


